import json
from abc import ABC
from distutils.util import strtobool
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Union

import yaml
from omagent_core.base import BotBase
from omagent_core.models.od.schemas import Target
from omagent_core.services.handlers.sql_data_handler import SQLDataHandler
from omagent_core.utils.error import VQLError
from omagent_core.utils.logger import logging
from omagent_core.utils.plot import Annotator
from PIL import Image
from pydantic import BaseModel, model_validator


class ArgSchema(BaseModel):
    """Describe a tool's input arguments as a flat ``name -> ArgInfo`` mapping.

    Only a single level of definition is supported; avoid complex or nested
    structures. Every keyword passed to the constructor is treated as one
    argument definition and is normalised to an ``ArgInfo`` by
    ``validate_all`` (extra fields are allowed by ``Config``).
    """

    class Config:
        """Configuration for this pydantic object."""

        extra = "allow"
        arbitrary_types_allowed = True

    class ArgInfo(BaseModel):
        """Definition of a single tool argument."""

        # Free-text explanation of the argument.
        description: Optional[str]
        # Declared type name. ``validate_args`` understands the Python-style
        # names (str/int/float/bool) and the JSON-schema-style names
        # (string/integer/number/boolean).
        type: str = "str"
        # Optional list of allowed values.
        enum: Optional[List] = None
        # Whether the argument must be present in the validated args.
        required: Optional[bool] = True

    @model_validator(mode="before")
    @classmethod
    def validate_all(cls, values):
        """Normalise every raw argument definition into an ``ArgInfo``.

        Accepted forms per argument:
            * ``str``     - shorthand, used as the argument description.
            * ``dict``    - keyword arguments for ``ArgInfo``.
            * ``ArgInfo`` - used as is.

        Raises:
            ValueError: If a definition is none of the accepted forms.
        """
        normalized = {}
        for key, value in values.items():
            if isinstance(value, cls.ArgInfo):
                normalized[key] = value
            elif isinstance(value, str):
                # ArgInfo has no ``name`` field; the shorthand string is the
                # description (the only field without a default).
                normalized[key] = cls.ArgInfo(description=value)
            elif isinstance(value, dict):
                normalized[key] = cls.ArgInfo(**value)
            else:
                raise ValueError(
                    "The arg type must be one of string, dict or self.ArgInfo."
                )
        return normalized

    @classmethod
    def from_file(cls, schema_file: Union[str, Path]):
        """Build an ``ArgSchema`` from a JSON or YAML file.

        Args:
            schema_file: Path to a ``.json``, ``.yaml`` or ``.yml`` file
                (suffix is matched case-insensitively).

        Raises:
            ValueError: If the file suffix is not a supported one.
        """
        schema_path = Path(schema_file)
        suffix = schema_path.suffix.lower()
        if suffix == ".json":
            with open(schema_path, "r") as f:
                schema = json.load(f)
        elif suffix in (".yaml", ".yml"):
            with open(schema_path, "r") as f:
                # NOTE(review): FullLoader can construct more than plain data
                # types; prefer yaml.safe_load if schema files may come from
                # an untrusted source.
                schema = yaml.load(f, Loader=yaml.FullLoader)
        else:
            raise ValueError("Only support json and yaml file.")
        return cls(**schema)

    def generate_schema(self) -> tuple:
        """Split the schema into parameter definitions and required names.

        Returns:
            tuple: ``(parameters, required_args)`` where ``parameters`` maps
            each argument name to its dumped definition (``None`` values and
            the ``required`` flag removed) and ``required_args`` lists the
            names whose ``required`` flag is truthy.
        """
        required_args = []
        parameters = {}
        for key, value in self.model_dump(exclude_none=True).items():
            # ``required`` is absent when it was None (exclude_none=True),
            # so pop with a default instead of raising KeyError.
            if value.pop("required", None):
                required_args.append(key)
            parameters[key] = value
        return parameters, required_args

    def validate_args(self, args: dict) -> dict:
        """Validate and coerce call arguments against this schema.

        Unknown arguments are dropped with a warning. A value whose Python
        type already matches the declared type is checked against ``enum``
        and kept; any other value is coerced to the declared type.

        Args:
            args: Mapping of argument name to raw value.

        Returns:
            dict: The accepted arguments, coerced to their declared types.

        Raises:
            ValueError: If ``args`` is not a dict, a value is not in its
                ``enum``, a value cannot be coerced, or the declared type is
                not supported.
            VQLError: If a required argument is missing.
        """
        if not isinstance(args, dict):
            raise ValueError(
                "ArgSchema validate only support dict, not {}".format(type(args))
            )

        # Dump once; the definitions do not change while validating.
        schema = self.model_dump()
        name_mapping = {
            "str": "string",
            "int": "integer",
            "float": "number",
            "bool": "boolean",
        }
        # declared type -> (coercion function, wording used in the error text)
        coercers = {
            "string": (str, "a str"),
            "integer": (int, "an int"),
            "number": (float, "a float"),
            # strtobool returns 0/1; wrap it so callers get a real bool.
            "boolean": (lambda raw: bool(strtobool(str(raw))), "a boolean"),
        }

        new_args = {}
        for name, value in args.items():
            if name not in schema:
                logging.warning(
                    "The input args includes an unnecessary parameter {}. Removed from the args.".format(
                        name
                    )
                )
                continue

            spec = schema[name]
            # Accept both "str" (the ArgInfo default) and "string" spellings.
            expected = name_mapping.get(spec["type"], spec["type"])
            # ``.get`` so values of other types (None, list, ...) fall through
            # to coercion instead of raising KeyError.
            actual = name_mapping.get(type(value).__name__)

            if actual == expected:
                if spec["enum"] and value not in spec["enum"]:
                    raise ValueError(
                        "The value of {} should be one of {}, but got {}".format(
                            name, str(spec["enum"]), value
                        )
                    )
                new_args[name] = value
            elif expected in coercers:
                coerce, wording = coercers[expected]
                try:
                    new_args[name] = coerce(value)
                except Exception as err:
                    raise ValueError(
                        "Parameter {} type expect {} value, but got a {} {}".format(
                            name, wording, type(value), value
                        )
                    ) from err
            else:
                raise ValueError(
                    "Parameter {} type expect one of string, integer, number and boolean, but got a {} {}".format(
                        name, spec["type"], type(value)
                    )
                )

        required_fields = {key for key, spec in schema.items() if spec["required"]}
        missing = required_fields - set(new_args.keys())
        if missing:
            raise VQLError("The required fields {} are missing.".format(missing))
        return new_args


class BaseTool(BotBase, ABC):
    """Base class for tools that can be described to and invoked by an agent.

    A tool either overrides ``_run`` / ``_arun`` or is given a callable via
    ``func``. When ``args_schema`` is set, the input of ``run`` / ``arun`` is
    validated against it before the call.
    """

    # Text describing the tool; emitted by ``generate_schema``.
    description: str
    # Optional callable used by the default ``_run`` / ``_arun``.
    func: Optional[Callable] = None
    # Optional definition of the accepted input arguments.
    args_schema: Optional[ArgSchema]
    # Extra keyword arguments passed to every ``_run`` / ``_arun`` call.
    special_params: Dict = {}

    def model_post_init(self, __context: Any) -> None:
        """Register this tool as ``_parent`` of every ``BotBase`` attribute."""
        for attr_value in self.__dict__.values():
            if isinstance(attr_value, BotBase):
                attr_value._parent = self

    @property
    def workflow_instance_id(self) -> Optional[str]:
        """Workflow instance id of the parent, or ``None`` without a parent."""
        if hasattr(self, "_parent"):
            return self._parent.workflow_instance_id
        return None

    @workflow_instance_id.setter
    def workflow_instance_id(self, value: str):
        # The id lives on the parent; silently ignored when there is none.
        if hasattr(self, "_parent"):
            self._parent.workflow_instance_id = value

    def _run(self, **input) -> str:
        """Implement this function or pass 'func' arg when initializing."""
        if self.func is None:
            raise NotImplementedError(
                "Implement _run or pass 'func' arg when initializing."
            )
        return self.func(**input)

    async def _arun(self, **input) -> str:
        """Implement this function or pass 'func' arg when initializing."""
        if self.func is None:
            raise NotImplementedError(
                "Implement _arun or pass 'func' arg when initializing."
            )
        return await self.func(**input)

    def _validated_input(self, input: Any) -> Any:
        """Return ``input`` validated against ``args_schema`` when one is set.

        Raises:
            ValueError: If a schema is set and ``input`` is not a dict.
        """
        if self.args_schema is None:
            return input
        if not isinstance(input, dict):
            raise ValueError(
                "The input type must be dict when args_schema is specified."
            )
        # Use the returned dict: it holds the coerced values and no longer
        # contains the arguments validate_args reported as removed.
        return self.args_schema.validate_args(input)

    def run(self, input: Any) -> str:
        """Validate ``input`` and execute the tool synchronously.

        Args:
            input: Keyword arguments for the tool; must be a dict when
                ``args_schema`` is set.

        Returns:
            Whatever ``_run`` returns.
        """
        input = self._validated_input(input)
        return self._run(**input, **self.special_params)

    async def arun(self, input: Any) -> str:
        """Validate ``input`` and execute the tool asynchronously.

        Args:
            input: Keyword arguments for the tool; must be a dict when
                ``args_schema`` is set.

        Returns:
            Whatever ``_arun`` returns.
        """
        input = self._validated_input(input)
        return await self._arun(**input, **self.special_params)

    def generate_schema(self):
        """Build the function-style schema dict describing this tool.

        Returns:
            dict: With an ``args_schema``, the parameters are its properties
            and required names. Without one, a single required ``input``
            parameter is declared.
        """
        if self.args_schema is None:
            return {
                "type": "function",
                # Kept at the top level for existing consumers.
                "description": self.description,
                "function": {
                    "name": self.name,
                    # Also placed here so both branches expose the description
                    # at the same location.
                    "description": self.description,
                    "parameters": {
                        "type": "object",
                        "name": "input",
                        "required": ["input"],
                    },
                },
            }

        properties, required = self.args_schema.generate_schema()
        return {
            "type": "function",
            "function": {
                "name": self.name,
                "description": self.description,
                "parameters": {
                    "type": "object",
                    "properties": properties,
                    "required": required,
                },
            },
        }


class BaseModelTool(BaseTool, ABC):
    """Base class for model-backed tools that produce detection targets."""

    # data_handler: Optional[SQLDataHandler]

    def visual_prompting(
        self,
        image: Image.Image,
        annotation: List[Target],
        prompting_type: str = "label_on_img",
        include_labels: Union[List, set, tuple] = None,
        exclude_labels: Union[List, set, tuple] = None,
    ) -> List[Image.Image]:
        """Draw the bounding boxes of selected targets onto ``image``.

        Args:
            image: Image handed to the ``Annotator``.
            annotation: Targets to draw.
            prompting_type: Not read by this implementation.
            include_labels: When given, only targets with these labels are drawn.
            exclude_labels: When given, targets with these labels are skipped.

        Returns:
            The value of ``Annotator.result()`` after all boxes were drawn.
        """
        drawer = Annotator(image)
        for target in annotation:
            # An explicit exclusion is checked before the inclusion filter.
            if exclude_labels is not None and target.label in exclude_labels:
                continue
            if include_labels is not None and target.label not in include_labels:
                continue
            if not target.bbox:
                # TODO: Add polygon support
                continue
            drawer.box_label(target.bbox, target.label, color="red")
        return drawer.result()

    def infer(self, images: List[Image.Image], kwargs) -> List[List[Target]]:
        """Run model inference on ``images``. Only OD type detection is supported.

        This base method has no body beyond its docstring, so it returns
        ``None`` unless overridden.

        Args:
            images (List[Image.Image]): Input images as PIL Image objects.
            kwargs (dict): Additional arguments for the model.

        Returns:
            List[List[Target]]: The detection results.
        """

    def ainfer(self, images: List[Image.Image], kwargs) -> List[List[Target]]:
        """Async counterpart of ``infer``. Only OD type detection is supported.

        This base method is a plain ``def`` with no body beyond its
        docstring, so it returns ``None`` unless overridden.

        Args:
            images (List[Image.Image]): Input images as PIL Image objects.
            kwargs (dict): Additional arguments for the model.

        Returns:
            List[List[Target]]: The detection results.
        """


class MemoryTool(BaseTool):
    """Tool backed by a ``SQLDataHandler`` and the table it exposes."""

    memory_handler: Optional[SQLDataHandler]

    def generate_schema(self) -> dict:
        """Describe the handler's data table as a plain dict.

        Returns:
            dict: ``{"table_name": ..., "columns": [...]}`` where every column
            entry carries its ``name``, ``type`` (the column type's
            ``__visit_name__``) and ``info``.
        """
        table = self.memory_handler.table
        table_name = table.__tablename__
        column_specs = [
            {
                "name": col.name,
                "type": col.type.__visit_name__,
                "info": col.info,
            }
            for col in table.__table__.columns
        ]
        return {"table_name": table_name, "columns": column_specs}

    def generate_prompt(self):
        """Placeholder; does nothing and returns ``None``."""
        pass

    def _run(self):
        """Call ``execute_sql`` on the memory handler; its result is not returned."""
        # NOTE(review): takes no keyword arguments, while BaseTool.run forwards
        # the input and special_params as keywords - confirm intended usage.
        self.memory_handler.execute_sql()

    async def _arun(self):
        """Async entry point; calls ``execute_sql`` the same way as ``_run``."""
        # NOTE(review): the call is not awaited - confirm execute_sql is synchronous.
        self.memory_handler.execute_sql()
